# -*-Python-*-

# The vanilla transformer model but with transparent attention between the encoder and
# decoder. See https://arxiv.org/abs/1808.07561.


import mesh_tensorflow.transformer.transformer_layers

# Layer list bound to `transformer.make_layer_stack` within the `encoder` gin
# scope: a self-attention layer followed by a dense-ReLU-dense layer.
# The list has two entries, which matches
# `TransparentEncDecAttention.layers_per_encoder_module = 2` further down.
# NOTE(review): presumably make_layer_stack repeats this list once per
# encoder module/layer -- confirm against its definition.
encoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]
# Layer list bound to `transformer.make_layer_stack` within the `decoder` gin
# scope: self-attention, then the transparent encoder-decoder attention layer
# (in the position between self-attention and the feed-forward layer), then a
# dense-ReLU-dense layer.
# The list has three entries, which matches
# `TransparentEncDecAttention.layers_per_decoder_module = 3` further down.
decoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.TransparentEncDecAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]

# Parameters for the TransparentEncDecAttention layer used in the decoder stack.
#
# layers_per_encoder_module / layers_per_decoder_module equal the lengths of
# the encoder (2) and decoder (3) layer lists bound above.
# NOTE(review): presumably these must be kept in sync with those lists if
# layers are added or removed -- confirm against TransparentEncDecAttention.
TransparentEncDecAttention.layers_per_encoder_module = 2
TransparentEncDecAttention.layers_per_decoder_module = 3
# `%dropout_rate` and `%num_layers` are gin macro references; neither macro is
# defined in this file, so they must be provided by another config file or
# binding that is parsed together with this one.
TransparentEncDecAttention.dropout_rate = %dropout_rate
# Encoder and decoder module counts are both tied to the same %num_layers macro.
TransparentEncDecAttention.encoder_num_modules = %num_layers
TransparentEncDecAttention.decoder_num_modules = %num_layers

